{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sb_auto_header",
    "tags": [
     "sb_auto_header"
    ]
   },
   "source": [
    "<!-- This cell is automatically updated by tools/tutorial-cell-updater.py -->\n",
    "<!-- The contents are initialized from tutorials/notebook-header.md -->\n",
    "\n",
    "[<img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>](https://colab.research.google.com/github/speechbrain/speechbrain/blob/develop/docs/tutorials/tasks/source-separation.ipynb)\n",
    "to execute or view/download this notebook on\n",
    "[GitHub](https://github.com/speechbrain/speechbrain/tree/develop/docs/tutorials/tasks/source-separation.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fabIezbT125I"
   },
   "source": [
    "# Source Separation\n",
    "\n",
    "## Introduction\n",
    "\n",
    "In source separation, the goal is to recover the individual source signals from an observed mixture, which is a superposition of several sources. Let us demonstrate this with an example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "2ZyBlnjRvetT"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "T = 1000\n",
    "t = np.arange(0, T)\n",
    "fs = 3000\n",
    "f0 = 10\n",
    "\n",
    "source1 = np.sin(2*np.pi*(f0/fs)*t) + 0.1*np.random.randn(T)\n",
    "source2 = np.sin(2*np.pi*(3*f0/fs)*t)+ 0.1*np.random.randn(T)\n",
    "mixture = source1 + source2\n",
    "\n",
    "plt.subplot(311)\n",
    "plt.plot(source1)\n",
    "plt.title('Source 1')\n",
    "plt.xticks([])\n",
    "\n",
    "plt.subplot(312)\n",
    "plt.plot(source2)\n",
    "plt.title('Source 2')\n",
    "plt.xticks([])\n",
    "\n",
    "plt.subplot(313)\n",
    "plt.plot(mixture)\n",
    "plt.title('Mixture')\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tiofqq4zHuXW"
   },
   "source": [
    "The goal is to recover Source 1 and Source 2 from the mixture signal. In our case, Source 1 is a noisy sinusoid with frequency f0, and Source 2 is a noisy sinusoid with frequency 3*f0.\n",
    "\n",
    "## A toy example\n",
    "\n",
    "Now, let's consider a slightly more interesting case where source 1 is a sinusoid with a random frequency smaller than f_threshold, and source 2 is a sinusoid with a random frequency larger than f_threshold. Let's first build the dataset and the dataloaders with PyTorch utilities. We will then build a model that is able to separate out the sources successfully."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "D6TzkU5NR7s9"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.utils.data as data_utils\n",
    "import librosa.display as lrd\n",
    "\n",
    "N = 100\n",
    "f_th = 200\n",
    "fs = 8000\n",
    "\n",
    "T = 10000\n",
    "t = torch.arange(0, T).unsqueeze(0)\n",
    "f1 = torch.randint(5, f_th, (N, 1))\n",
    "f2 = torch.randint(f_th, 400, (N, 1))\n",
    "batch_size = 10\n",
    "\n",
    "source1 = torch.sin(2*np.pi*(f1/fs)*t)\n",
    "source2 = torch.sin(2*np.pi*(f2/fs)*t)\n",
    "mixture = source1 + source2\n",
    "N_train = 90\n",
    "train_dataset = data_utils.TensorDataset(source1[:N_train], source2[:N_train], mixture[:N_train])\n",
    "test_dataset = data_utils.TensorDataset(source1[N_train:], source2[N_train:], mixture[N_train:])\n",
    "\n",
    "train_loader = data_utils.DataLoader(train_dataset, batch_size=batch_size)\n",
    "test_loader = data_utils.DataLoader(test_dataset, batch_size=batch_size)\n",
    "\n",
    "# now let's visualize the frequency spectra for the dataset\n",
    "fft_size = 200\n",
    "\n",
    "plt.figure(figsize=[20, 10], dpi=50)\n",
    "\n",
    "plt.subplot(131)\n",
    "mix_gt = mixture[N_train]\n",
    "mix_spec = torch.sqrt((torch.view_as_real(torch.stft(mix_gt, n_fft=fft_size, return_complex=True))**2).sum(-1))\n",
    "lrd.specshow(mix_spec.numpy(), y_axis='log')\n",
    "plt.title('Mixture Spectrogram')\n",
    "\n",
    "plt.subplot(132)\n",
    "source1_gt = source1[N_train]\n",
    "source1_spec = torch.sqrt((torch.view_as_real(torch.stft(source1_gt, n_fft=fft_size, return_complex=True))**2).sum(-1))\n",
    "lrd.specshow(source1_spec.numpy(), y_axis='log')\n",
    "plt.title('Source 1 Spectrogram')\n",
    "\n",
    "plt.subplot(133)\n",
    "source2_gt = source2[N_train]\n",
    "source2_spec = torch.sqrt((torch.view_as_real(torch.stft(source2_gt, n_fft=fft_size, return_complex=True))**2).sum(-1))\n",
    "lrd.specshow(source2_spec.numpy(), y_axis='log')\n",
    "plt.title('Source 2 Spectrogram')\n",
    "\n",
    "plt.show()\n"
   ]
  },
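  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The mixture spectrogram shows the frequency lines of both sources at once because the STFT is linear: the transform of a sum is the sum of the transforms. Below is a quick numerical check of this on synthetic signals (the variable names are ours, chosen just for this sketch)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# STFT linearity check: stft(a + b) == stft(a) + stft(b)\n",
    "t_lin = torch.arange(2000)\n",
    "a = torch.sin(2 * torch.pi * 50 * t_lin / 8000)\n",
    "b = torch.sin(2 * torch.pi * 300 * t_lin / 8000)\n",
    "A = torch.stft(a, n_fft=200, return_complex=True)\n",
    "B = torch.stft(b, n_fft=200, return_complex=True)\n",
    "M = torch.stft(a + b, n_fft=200, return_complex=True)\n",
    "print(torch.allclose(M, A + B, atol=1e-5))"
   ]
  },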
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CeG7ghFQr589"
   },
   "source": [
    "Now that we have created the dataset, we can focus on building a model that is able to recover the original sources from the mixture signal. For this, we will use SpeechBrain. Let us first install it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "1h4PPdSR7YJd"
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "# Installing SpeechBrain via pip\n",
    "BRANCH = 'develop'\n",
    "!python -m pip install git+https://github.com/speechbrain/speechbrain.git@$BRANCH"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hUqDGcXt7nlg"
   },
   "source": [
    "Now, let us construct a simple source separation model with PyTorch and SpeechBrain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "WPRLjPWO_9k2"
   },
   "outputs": [],
   "source": [
    "import speechbrain as sb\n",
    "import torch.nn as nn\n",
    "\n",
    "# define the model\n",
    "class simpleseparator(nn.Module):\n",
    "  def __init__(self, fft_size, hidden_size, num_sources=2):\n",
    "    super(simpleseparator, self).__init__()\n",
    "    self.masking = nn.LSTM(input_size=fft_size//2 + 1, hidden_size=hidden_size, batch_first=True, bidirectional=True)\n",
    "    self.output_layer = nn.Linear(in_features=hidden_size*2, out_features=num_sources*(fft_size//2 + 1))\n",
    "    self.fft_size=fft_size\n",
    "    self.num_sources = num_sources\n",
    "\n",
    "  def forward(self, inp):\n",
    "    # batch x freq x time x realim\n",
    "    y = torch.view_as_real(torch.stft(inp, n_fft=self.fft_size, return_complex=True))\n",
    "\n",
    "    # batch X freq x time\n",
    "    mag = torch.sqrt((y ** 2).sum(-1))\n",
    "    phase = torch.atan2(y[:, :, :, 1], y[:, :, :, 0])\n",
    "\n",
    "    # batch x time x freq\n",
    "    mag = mag.permute(0, 2, 1)\n",
    "\n",
    "    # batch x time x feature\n",
    "    rnn_out = self.masking(mag)[0]\n",
    "\n",
    "    # batch x time x (nfft*num_sources)\n",
    "    lin_out = self.output_layer(rnn_out)\n",
    "\n",
    "    # batch x time x nfft x num_sources\n",
    "    lin_out = nn.functional.relu(lin_out.reshape(lin_out.size(0), lin_out.size(1), -1, self.num_sources))\n",
    "\n",
    "    # reconstruct in time domain\n",
    "    sources = []\n",
    "    all_masks = []\n",
    "    for n in range(self.num_sources):\n",
    "      sourcehat_mask = (lin_out[:, :, :, n])\n",
    "      all_masks.append(sourcehat_mask)\n",
    "\n",
    "      # apply the mask to the magnitude spectrogram\n",
    "      sourcehat_dft = (sourcehat_mask * mag).permute(0, 2, 1) * torch.exp(1j * phase)\n",
    "\n",
    "      # reconstruct in time domain with istft\n",
    "      sourcehat = torch.istft(sourcehat_dft, n_fft=self.fft_size)\n",
    "      sources.append(sourcehat)\n",
    "    return sources, all_masks, mag\n",
    "\n",
    "# test_forwardpass\n",
    "model = simpleseparator(fft_size=fft_size, hidden_size=300)\n",
    "est_sources, _, _ = model.forward(mixture[:5])"
   ]
  },
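  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The forward pass above splits the STFT into magnitude and phase and reuses the mixture's phase when reconstructing each source. As a quick sanity check (a sketch with our own variable names), keeping the full magnitude and the original phase, i.e. applying an all-ones mask, lets `torch.istft` reconstruct the input almost exactly:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# magnitude/phase round trip: the ISTFT of mag * exp(1j * phase)\n",
    "# recovers the original signal up to floating-point error\n",
    "x = torch.sin(2 * torch.pi * 30 * torch.arange(4000) / 8000)\n",
    "X = torch.stft(x, n_fft=200, return_complex=True)\n",
    "mag_x, phase_x = X.abs(), X.angle()\n",
    "x_rec = torch.istft(mag_x * torch.exp(1j * phase_x), n_fft=200)\n",
    "print((x[: x_rec.size(-1)] - x_rec).abs().max())"
   ]
  },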
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7KQ1ZKoyfM82"
   },
   "source": [
    "Now that we have our model, we can write the Brain class for training.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "U78x9qR3j3kN"
   },
   "outputs": [],
   "source": [
    "class SeparationBrain(sb.Brain):\n",
    "    def __init__(self, train_loss, modules, opt_class):\n",
    "        super(SeparationBrain, self).__init__(modules=modules, opt_class=opt_class)\n",
    "        self.train_loss = train_loss\n",
    "\n",
    "    def compute_forward(self, mix):\n",
    "        \"\"\"Forward computations from the mixture to the separated signals.\"\"\"\n",
    "\n",
    "        # Get the estimates for the sources\n",
    "        est_sources, _, _ = self.modules.mdl(mix)\n",
    "\n",
    "        est_sources = torch.stack(est_sources, dim=-1)\n",
    "\n",
    "        # The estimates can differ in length from the input; pad or crop to match\n",
    "        T_origin = mix.size(1)\n",
    "        T_est = est_sources.size(1)\n",
    "        if T_origin > T_est:\n",
    "            est_sources = nn.functional.pad(est_sources, (0, 0, 0, T_origin - T_est))\n",
    "        else:\n",
    "            est_sources = est_sources[:, :T_origin, :]\n",
    "\n",
    "        return est_sources\n",
    "\n",
    "    def compute_objectives(self, targets, est_sources):\n",
    "        \"\"\"Computes the loss functions between estimated and ground truth sources\"\"\"\n",
    "        if self.train_loss == 'l1':\n",
    "          return (est_sources - targets).abs().mean()\n",
    "        elif self.train_loss == 'si-snr':\n",
    "          return sb.nnet.losses.get_si_snr_with_pitwrapper(targets, est_sources).mean()\n",
    "\n",
    "\n",
    "    def fit_batch(self, batch):\n",
    "        \"\"\"Trains one batch\"\"\"\n",
    "        # Unpacking batch list\n",
    "        source1, source2, mix = batch\n",
    "        targets = torch.stack([source1, source2], dim=-1)\n",
    "\n",
    "        est_sources = self.compute_forward(mix)\n",
    "        loss = self.compute_objectives(targets, est_sources)\n",
    "\n",
    "        loss.backward()\n",
    "        self.optimizer.step()\n",
    "        self.optimizer.zero_grad()\n",
    "        return loss.detach().cpu()\n",
    "\n",
    "    def evaluate_batch(self, batch, stage):\n",
    "        \"\"\"Computations needed for test batches\"\"\"\n",
    "\n",
    "        source1, source2, mix = batch\n",
    "        targets = torch.stack([source1, source2], dim=-1)\n",
    "\n",
    "        est_sources = self.compute_forward(mix)\n",
    "\n",
    "        si_snr = sb.nnet.losses.get_si_snr_with_pitwrapper(targets, est_sources)\n",
    "        si_snr_mean = si_snr.mean().item()\n",
    "        print('VALID SI-SNR = {}'.format(-si_snr_mean))\n",
    "        return si_snr.mean().detach()\n",
    "\n",
    "\n",
    "optimizer = lambda x: torch.optim.Adam(x, lr=0.0001)\n",
    "N_epochs = 10\n",
    "epoch_counter = sb.utils.epoch_loop.EpochCounter(limit=N_epochs)\n",
    "\n",
    "separator = SeparationBrain(\n",
    "    train_loss='l1',\n",
    "    modules={'mdl': model},\n",
    "    opt_class=optimizer,\n",
    ")\n",
    "\n",
    "separator.fit(epoch_counter, train_loader, test_loader)\n"
   ]
  },
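  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `si-snr` loss option and the validation metric above use `get_si_snr_with_pitwrapper`, which is permutation invariant: since the separator can output the sources in either order, the loss is evaluated under every ordering of the estimates and the best one is kept. Here is a minimal sketch of that idea with an L1 loss (the function `pit_l1` is ours, not part of SpeechBrain; SpeechBrain performs the same search with SI-SNR):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from itertools import permutations\n",
    "\n",
    "def pit_l1(targets, ests):\n",
    "    # targets, ests: [batch, time, num_sources]\n",
    "    num_sources = targets.size(-1)\n",
    "    losses = [\n",
    "        (ests[..., list(perm)] - targets).abs().mean()\n",
    "        for perm in permutations(range(num_sources))\n",
    "    ]\n",
    "    # keep the loss of the best ordering of the estimates\n",
    "    return torch.stack(losses).min()\n",
    "\n",
    "t_pit = torch.arange(1000) / 1000\n",
    "s1 = torch.sin(2 * torch.pi * 5 * t_pit).unsqueeze(0)\n",
    "s2 = torch.sin(2 * torch.pi * 20 * t_pit).unsqueeze(0)\n",
    "targets = torch.stack([s1, s2], dim=-1)\n",
    "swapped = torch.stack([s2, s1], dim=-1)\n",
    "# the search undoes the swap, so the loss is zero\n",
    "print(pit_l1(targets, swapped))"
   ]
  },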
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xEI7cJw00UYU"
   },
   "source": [
    "Now, let's visualize the results. `librosa` has a nice tool, `specshow`, for visualizing spectrograms; the next cell makes sure it is installed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "I0ul7PjF9HwC"
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "!pip install librosa\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "69mRj3DKAa4F"
   },
   "source": [
    "We will first plot the spectra of the ground-truth sources, and then run a forward pass with the model and plot the estimated sources."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "j6NbKy4S1yl_"
   },
   "outputs": [],
   "source": [
    "estimated_sources, all_masks, mag = separator.modules.mdl.forward(mixture[N_train:])\n",
    "\n",
    "\n",
    "plt.figure(figsize=[20, 10], dpi=80)\n",
    "\n",
    "plt.subplot(331)\n",
    "mag = mag[0].t().numpy()\n",
    "lrd.specshow(mag, y_axis='log')\n",
    "plt.title('Mixture')\n",
    "\n",
    "plt.subplot(334)\n",
    "mask1 = all_masks[0][0].detach().t().numpy()\n",
    "lrd.specshow(mask1, y_axis='log')\n",
    "plt.title('Mask for source 1')\n",
    "\n",
    "plt.subplot(335)\n",
    "masked1 = mask1 * mag\n",
    "lrd.specshow(masked1, y_axis='log')\n",
    "plt.title('Estimated Source 1')\n",
    "\n",
    "plt.subplot(336)\n",
    "source1_gt = source1[N_train]\n",
    "source1_spec = torch.sqrt((torch.view_as_real(torch.stft(source1_gt, n_fft=fft_size, return_complex=True))**2).sum(-1))\n",
    "lrd.specshow(source1_spec.numpy(), y_axis='log')\n",
    "plt.title('Ground Truth Source 1')\n",
    "\n",
    "plt.subplot(337)\n",
    "mask2 = all_masks[1][0].detach().t().numpy()\n",
    "lrd.specshow(mask2, y_axis='log')\n",
    "plt.title('Mask for Source 2')\n",
    "\n",
    "plt.subplot(338)\n",
    "masked2 = mask2 * mag\n",
    "lrd.specshow(masked2, y_axis='log')\n",
    "plt.title('Estimated Source 2')\n",
    "\n",
    "plt.subplot(339)\n",
    "source2_gt = source2[N_train]\n",
    "source2_spec = torch.sqrt((torch.view_as_real(torch.stft(source2_gt, n_fft=fft_size, return_complex=True))**2).sum(-1))\n",
    "lrd.specshow(source2_spec.numpy(), y_axis='log')\n",
    "plt.title('Ground Truth Source 2')\n",
    "\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sBmdv0-jq8yR"
   },
   "source": [
    "Notice that these masks are essentially band-stop filters which aim to remove the interference from the other source.\n",
    "\n",
    "## Exercises\n",
    "\n",
    "* Train the same model with SI-SNR loss and observe if this helps to improve the performance.\n",
    "* Replace the STFT front end and the ISTFT reconstruction with a convolutional layer and a transposed convolutional layer. Repeat the visualization above; also visualize the filters learned by the convolutional front end and the reconstruction layer, and compare them with the DFT bases.\n",
    "\n",
    "\n",
    "## A sound source separation example with real speech mixtures\n",
    "\n",
    "First, let's download the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NUQX-saDsjPF"
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "!wget https://www.dropbox.com/sh/07vwpwru6qo6yhf/AADx5I8oV0IdekCf80MSkxMia/mixture_0.wav\n",
    "!wget https://www.dropbox.com/sh/07vwpwru6qo6yhf/AAAZI7ZezKyHFGPdus6hn2v_a/mixture_1.wav\n",
    "!wget https://www.dropbox.com/sh/07vwpwru6qo6yhf/AACh4Yy4H-Ii2I0mr_b1lQdXa/mixture_2.wav\n",
    "!wget https://www.dropbox.com/sh/07vwpwru6qo6yhf/AAAenTlEsoj1-AGbCxeJfMHoa/mixture_3.wav\n",
    "!wget https://www.dropbox.com/sh/07vwpwru6qo6yhf/AAC-awQo-9NFVVULuVwaHKKWa/source1_0.wav\n",
    "!wget https://www.dropbox.com/sh/07vwpwru6qo6yhf/AABVKWtdVhXZE6Voq1I_c6g5a/source1_1.wav\n",
    "!wget https://www.dropbox.com/sh/07vwpwru6qo6yhf/AAC9EfjTTwL0dscH16waP9s-a/source1_2.wav\n",
    "!wget https://www.dropbox.com/sh/07vwpwru6qo6yhf/AAC5Ozb4rS9qby268JSIy5Uwa/source1_3.wav\n",
    "!wget https://www.dropbox.com/sh/07vwpwru6qo6yhf/AABlonG910Ms2l-rTN5ct3Oka/source2_0.wav\n",
    "!wget https://www.dropbox.com/sh/07vwpwru6qo6yhf/AACDOqEgyXIeA2r1Rkf7VgQTa/source2_1.wav\n",
    "!wget https://www.dropbox.com/sh/07vwpwru6qo6yhf/AACTYGAG0LOh6HvxpVYoqO_Da/source2_2.wav\n",
    "!wget https://www.dropbox.com/sh/07vwpwru6qo6yhf/AACPmq-ZJNzfh4bnO34_8mfAa/source2_3.wav"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "84g08IFDmzbo"
   },
   "source": [
    "Now, let's listen to these sounds."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vgriZkbNmssS"
   },
   "outputs": [],
   "source": [
    "import speechbrain\n",
    "from speechbrain.dataio.dataio import read_audio\n",
    "from IPython.display import Audio\n",
    "\n",
    "mixture_0 = read_audio('mixture_0.wav').squeeze()\n",
    "source1_0 = read_audio('source1_0.wav').squeeze()\n",
    "source2_0 = read_audio('source2_0.wav').squeeze()\n",
    "\n",
    "mixture_1 = read_audio('mixture_1.wav').squeeze()\n",
    "source1_1 = read_audio('source1_1.wav').squeeze()\n",
    "source2_1 = read_audio('source2_1.wav').squeeze()\n",
    "\n",
    "mixture_2 = read_audio('mixture_2.wav').squeeze()\n",
    "source1_2 = read_audio('source1_2.wav').squeeze()\n",
    "source2_2 = read_audio('source2_2.wav').squeeze()\n",
    "\n",
    "mixture_3 = read_audio('mixture_3.wav').squeeze()\n",
    "source1_3 = read_audio('source1_3.wav').squeeze()\n",
    "source2_3 = read_audio('source2_3.wav').squeeze()\n",
    "\n",
    "train_mixs = [mixture_0, mixture_1, mixture_2]\n",
    "train_source1s = [source1_0, source1_1, source1_2]\n",
    "train_source2s = [source2_0, source2_1, source2_2]\n",
    "\n",
    "Audio(mixture_0, rate=16000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AlRgO_rT33GB"
   },
   "outputs": [],
   "source": [
    "Audio(source1_0, rate=16000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "C9HzYrVl3uNc"
   },
   "outputs": [],
   "source": [
    "Audio(source2_0, rate=16000)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8NeccEWR51Px"
   },
   "source": [
    "Now, let's construct the datasets and dataloaders."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "UqDjKaUG5z9S"
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "\n",
    "class source_separation_dataset(Dataset):\n",
    "    def __init__(self, train_mixs, train_source1s, train_source2s):\n",
    "        self.mixs = train_mixs\n",
    "        self.train_source1s = train_source1s\n",
    "        self.train_source2s = train_source2s\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.mixs)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        mix = self.mixs[idx]\n",
    "        source1 = self.train_source1s[idx]\n",
    "        source2 = self.train_source2s[idx]\n",
    "        # return order matches the unpacking in SeparationBrain.fit_batch\n",
    "        return source1, source2, mix\n",
    "\n",
    "train_dataset_audio = source_separation_dataset(train_mixs, train_source1s, train_source2s)\n",
    "valid_dataset_audio = source_separation_dataset([mixture_2], [source1_2], [source2_2])\n",
    "\n",
    "train_loader_audio = DataLoader(train_dataset_audio, batch_size=1)\n",
    "valid_loader_audio = DataLoader(valid_dataset_audio, batch_size=1)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "K1FeHiQrh7p3"
   },
   "source": [
    "And now, let's tinker with the model we constructed and use it on this small dataset. For this purpose, we will use the following mask-based end-to-end architecture:\n",
    "\n",
    "![end2end.png]()\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "EWW80iPjh4DJ"
   },
   "outputs": [],
   "source": [
    "fft_size = 1024\n",
    "model_audio = simpleseparator(fft_size=fft_size, hidden_size=300)\n",
    "\n",
    "\n",
    "optimizer = lambda x: torch.optim.Adam(x, lr=0.0005)\n",
    "N_epochs = 100\n",
    "epoch_counter = sb.utils.epoch_loop.EpochCounter(limit=N_epochs)\n",
    "\n",
    "separator = SeparationBrain(\n",
    "    train_loss='si-snr',\n",
    "    modules={'mdl': model_audio},\n",
    "    opt_class=optimizer,\n",
    ")\n",
    "\n",
    "separator.fit(epoch_counter, train_loader_audio, valid_loader_audio)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9C2EQ5M20QsJ"
   },
   "outputs": [],
   "source": [
    "class audioseparator(nn.Module):\n",
    "  def __init__(self, fft_size, hidden_size, num_sources=2, kernel_size=16):\n",
    "    super(audioseparator, self).__init__()\n",
    "    self.encoder = nn.Conv1d(in_channels=1, out_channels=fft_size, kernel_size=kernel_size, stride=kernel_size//2)\n",
    "\n",
    "    # MaskNet\n",
    "    self.rnn = nn.LSTM(input_size=fft_size, hidden_size=hidden_size, batch_first=True, bidirectional=True)\n",
    "    self.output_layer = nn.Linear(in_features=hidden_size*2, out_features=num_sources*(fft_size))\n",
    "\n",
    "    self.decoder = nn.ConvTranspose1d(in_channels=fft_size, out_channels=1, kernel_size=kernel_size, stride=kernel_size//2)\n",
    "\n",
    "    self.fft_size = fft_size\n",
    "    self.hidden_size = hidden_size\n",
    "    self.num_sources = num_sources\n",
    "\n",
    "  def forward(self, inp):\n",
    "    # batch x channels x time\n",
    "    y = nn.functional.relu(self.encoder(inp.unsqueeze(1)))\n",
    "\n",
    "    # batch x time x nfft\n",
    "    y = y.permute(0, 2, 1)\n",
    "\n",
    "    # batch x time x feature\n",
    "    rnn_out = self.rnn(y)[0]\n",
    "\n",
    "    # batch x time x (nfft*num_sources)\n",
    "    lin_out = self.output_layer(rnn_out)\n",
    "\n",
    "    # batch x time x nfft x num_sources\n",
    "    lin_out = lin_out.reshape(lin_out.size(0), lin_out.size(1), -1, self.num_sources)\n",
    "\n",
    "    # reconstruct in time domain\n",
    "    sources = []\n",
    "    all_masks = []\n",
    "    for n in range(self.num_sources):\n",
    "      sourcehat_mask = nn.functional.relu(lin_out[:, :, :, n])\n",
    "      all_masks.append(sourcehat_mask)\n",
    "\n",
    "      # apply the mask to the encoder output\n",
    "      T = sourcehat_mask.size(1)\n",
    "      sourcehat_latent = (sourcehat_mask * y[:, :T, :]).permute(0, 2, 1)\n",
    "\n",
    "      # reconstruct in time domain with the transposed-convolution decoder\n",
    "      sourcehat = self.decoder(sourcehat_latent).squeeze(1)\n",
    "      sources.append(sourcehat)\n",
    "\n",
    "    return sources, all_masks, y\n",
    "\n",
    "model_audio = audioseparator(fft_size=fft_size, hidden_size=300, kernel_size=256)\n",
    "out, _, _ = model_audio.forward(mixture_0.unsqueeze(0))\n"
   ]
  },
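  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that the strided `Conv1d` encoder and the `ConvTranspose1d` decoder do not, in general, return the original number of samples; this is why `compute_forward` in the Brain class pads or crops the estimates. A quick check with the same kernel size and stride as above (the input length below is arbitrary):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "kernel_size, stride = 256, 128\n",
    "enc = nn.Conv1d(1, 1024, kernel_size=kernel_size, stride=stride)\n",
    "dec = nn.ConvTranspose1d(1024, 1, kernel_size=kernel_size, stride=stride)\n",
    "\n",
    "x = torch.randn(1, 1, 32005)\n",
    "frames = enc(x)       # (32005 - 256) // 128 + 1 = 249 frames\n",
    "x_rec = dec(frames)   # (249 - 1) * 128 + 256 = 32000 samples\n",
    "print(frames.size(-1), x_rec.size(-1))  # 249 32000: 5 samples lost"
   ]
  },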
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "s5dkjj0P5wLu"
   },
   "outputs": [],
   "source": [
    "optimizer = lambda x: torch.optim.Adam(x, lr=0.0002)\n",
    "N_epochs = 200\n",
    "epoch_counter = sb.utils.epoch_loop.EpochCounter(limit=N_epochs)\n",
    "\n",
    "separator = SeparationBrain(\n",
    "        train_loss='si-snr',\n",
    "        modules={'mdl': model_audio},\n",
    "        opt_class=optimizer\n",
    "\n",
    "    )\n",
    "\n",
    "separator.fit(\n",
    "            epoch_counter,\n",
    "            train_loader_audio,\n",
    "            valid_loader_audio)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5JUqqmL_KFGs"
   },
   "outputs": [],
   "source": [
    "estimated_sources_test, all_masks, mag = model_audio.forward(mixture_3.unsqueeze(0))\n",
    "estimated_sources_train, all_masks, mag = model_audio.forward(mixture_0.unsqueeze(0))\n",
    "\n",
    "\n",
    "Audio(estimated_sources_test[0].squeeze().detach(), rate=16000)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "2trZ6WX_kAsT"
   },
   "outputs": [],
   "source": [
    "Audio(estimated_sources_test[1].squeeze().detach(), rate=16000)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "nzvsrhsEm4f1"
   },
   "outputs": [],
   "source": [
    "Audio(estimated_sources_train[0].squeeze().detach(), rate=16000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "aAZOTrYDncYm"
   },
   "outputs": [],
   "source": [
    "Audio(estimated_sources_train[1].squeeze().detach(), rate=16000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "yUCKuExTnz4Z"
   },
   "source": [
    "The result is not perfect because of the artifacts the model introduces, but we can hear that it suppresses the interfering source."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sb_auto_footer",
    "tags": [
     "sb_auto_footer"
    ]
   },
   "source": [
    "## Citing SpeechBrain\n",
    "\n",
    "If you use SpeechBrain in your research or business, please cite it using the following BibTeX entry:\n",
    "\n",
    "```bibtex\n",
    "@misc{speechbrainV1,\n",
    "  title={Open-Source Conversational AI with {SpeechBrain} 1.0},\n",
    "  author={Mirco Ravanelli and Titouan Parcollet and Adel Moumen and Sylvain de Langen and Cem Subakan and Peter Plantinga and Yingzhi Wang and Pooneh Mousavi and Luca Della Libera and Artem Ploujnikov and Francesco Paissan and Davide Borra and Salah Zaiem and Zeyu Zhao and Shucong Zhang and Georgios Karakasidis and Sung-Lin Yeh and Pierre Champion and Aku Rouhe and Rudolf Braun and Florian Mai and Juan Zuluaga-Gomez and Seyed Mahed Mousavi and Andreas Nautsch and Xuechen Liu and Sangeet Sagar and Jarod Duret and Salima Mdhaffar and Gaelle Laperriere and Mickael Rouvier and Renato De Mori and Yannick Esteve},\n",
    "  year={2024},\n",
    "  eprint={2407.00463},\n",
    "  archivePrefix={arXiv},\n",
    "  primaryClass={cs.LG},\n",
    "  url={https://arxiv.org/abs/2407.00463},\n",
    "}\n",
    "@misc{speechbrain,\n",
    "  title={{SpeechBrain}: A General-Purpose Speech Toolkit},\n",
    "  author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},\n",
    "  year={2021},\n",
    "  eprint={2106.04624},\n",
    "  archivePrefix={arXiv},\n",
    "  primaryClass={eess.AS},\n",
    "  note={arXiv:2106.04624}\n",
    "}\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": [
    {
     "file_id": "1lcPtycPQOTVanhxj6GiOBnWuWtzXTGJw",
     "timestamp": 1613147076800
    },
    {
     "file_id": "18rvXsEWzSeGVXcrXB_AVIMnMzpUo0irv",
     "timestamp": 1613140070220
    }
   ]
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
